{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.10/site-packages/sentence_transformers/cross_encoder/CrossEncoder.py:13: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
      "  from tqdm.autonotebook import tqdm, trange\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9f6b9237597a4abfb261ef0a20358b52",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6b789dd073f54d168c94a966cc147ef6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "config_sentence_transformers.json:   0%|          | 0.00/266 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "39239f91dc824c1dbb3f807a5c3f3f54",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "README.md:   0%|          | 0.00/114k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "10fd5a1c643f4bfdbc61c8b2f6f8a8da",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b4df8a5251a949f5bd95bb83721be8b1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "config.json:   0%|          | 0.00/677 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "24efdf4958424fcba7b977b98126478e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors:   0%|          | 0.00/670M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "00a061f690e5469eb55d2c3317c27954",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer_config.json:   0%|          | 0.00/1.24k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6b07064692974f1ea96dc33c094feb71",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "634dc6841e4b41b6b2f46149f6032bb7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "556c3e2cad1e46ff8603694e97c1e310",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "special_tokens_map.json:   0%|          | 0.00/695 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8a28130ede9e4691859435ef8fab808f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "1_Pooling/config.json:   0%|          | 0.00/297 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import os\n",
    "import pickle\n",
    "import time\n",
    "import numpy\n",
    "from pyspark import SparkConf\n",
    "from sentence_transformers import SentenceTransformer, SimilarityFunction\n",
    "from pyspark.sql import SparkSession\n",
    "from aips import get_engine\n",
    "\n",
    "# Engine handle returned by aips.get_engine().\n",
    "engine = get_engine()\n",
    "\n",
    "#Recommended for making processing faster, if you have enough memory / cores allocated to docker\n",
    "conf = SparkConf()\n",
    "conf.set(\"spark.driver.memory\", \"8g\")\n",
    "conf.set(\"spark.executor.memory\", \"8g\")\n",
    "conf.set(\"spark.dynamicAllocation.enabled\", \"true\")\n",
    "# NOTE(review): this key does not appear among Spark's documented properties; the documented\n",
    "# overhead setting is `spark.executor.memoryOverhead` -- confirm which one was intended.\n",
    "conf.set(\"spark.dynamicAllocation.executorMemoryOverhead\", \"8g\")\n",
    "spark = SparkSession.builder.appName(\"AIPS-ch13\").config(conf=conf).getOrCreate()\n",
    "\n",
    "# Embedding model used by the cells below: dot-product similarity, outputs truncated to 1024 dimensions.\n",
    "model = SentenceTransformer(\"mixedbread-ai/mxbai-embed-large-v1\",\n",
    "                            similarity_fn_name=SimilarityFunction.DOT_PRODUCT,\n",
    "                            truncate_dim=1024)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Some helpful documentation and reference material\n",
    "\n",
    "#https://github.com/facebookresearch/faiss/wiki/Pre--and-post-processing\n",
    "#https://github.com/facebookresearch/faiss/wiki\n",
    "#https://huggingface.co/spaces/sentence-transformers/quantized-retrieval/blob/main/app.py\n",
    "\n",
    "#multi threading sentence_transformers_model\n",
    "#model.stop_multi_process_pool(pool)\n",
    "#pool = model.start_multi_process_pool()\n",
    "#embeddings = model.encode(texts, convert_to_tensor=False).tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Already up to date.\n",
      "README.md\n",
      "concepts.pickle\n",
      "._guesses.csv\n",
      "guesses.csv\n",
      "._guesses_all.json\n",
      "guesses_all.json\n",
      "outdoors_concepts.pickle\n",
      "outdoors_embeddings.pickle\n",
      "._outdoors_golden_answers.csv\n",
      "outdoors_golden_answers.csv\n",
      "._outdoors_golden_answers.xlsx\n",
      "outdoors_golden_answers.xlsx\n",
      "._outdoors_golden_answers_20210130.csv\n",
      "outdoors_golden_answers_20210130.csv\n",
      "outdoors_labels.pickle\n",
      "outdoors_question_answering_contexts.json\n",
      "outdoors_questionanswering_test_set.json\n",
      "outdoors_questionanswering_train_set.json\n",
      "._posts.csv\n",
      "posts.csv\n",
      "predicates.pickle\n",
      "pull_aips_dependency.py\n",
      "._question-answer-seed-contexts.csv\n",
      "question-answer-seed-contexts.csv\n",
      "question-answer-squad2-guesses.csv\n",
      "._roberta-base-squad2-outdoors\n",
      "roberta-base-squad2-outdoors/\n",
      "roberta-base-squad2-outdoors/._tokenizer_config.json\n",
      "roberta-base-squad2-outdoors/tokenizer_config.json\n",
      "roberta-base-squad2-outdoors/._special_tokens_map.json\n",
      "roberta-base-squad2-outdoors/special_tokens_map.json\n",
      "roberta-base-squad2-outdoors/._config.json\n",
      "roberta-base-squad2-outdoors/config.json\n",
      "roberta-base-squad2-outdoors/._merges.txt\n",
      "roberta-base-squad2-outdoors/merges.txt\n",
      "roberta-base-squad2-outdoors/._training_args.bin\n",
      "roberta-base-squad2-outdoors/training_args.bin\n",
      "roberta-base-squad2-outdoors/._pytorch_model.bin\n",
      "roberta-base-squad2-outdoors/pytorch_model.bin\n",
      "roberta-base-squad2-outdoors/._vocab.json\n",
      "roberta-base-squad2-outdoors/vocab.json\n"
     ]
    }
   ],
   "source": [
    "# Clone the outdoors dataset repository only when no local `outdoors` directory exists, then pull updates.\n",
    "![ ! -d 'outdoors' ] && git clone --depth=1 https://github.com/ai-powered-search/outdoors.git\n",
    "! cd outdoors && git pull\n",
    "# Join the outdoors.tgz.part* pieces into one archive and extract it into data/outdoors/\n",
    "# (the path is given relative to the checkout, hence the leading `../`).\n",
    "! cd outdoors && cat outdoors.tgz.part* > outdoors.tgz\n",
    "! cd outdoors && mkdir -p '../data/outdoors/' && tar -xvf outdoors.tgz -C '../data/outdoors/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_embeddings(texts, model, cache_name, ignore_cache=False):\n",
    "    \"\"\"Return embeddings for `texts`, reusing a pickle cache when one exists.\n",
    "\n",
    "    The cache file is data/embeddings/<cache_name>.pickle. When `ignore_cache`\n",
    "    is true or the file is missing, `model.encode(texts, normalize_embeddings=True)`\n",
    "    is called and its result is pickled to that path. Otherwise the pickled\n",
    "    value is loaded and returned without calling `model`.\n",
    "    \"\"\"\n",
    "    cache_path = f\"data/embeddings/{cache_name}.pickle\"\n",
    "    cache_hit = not ignore_cache and os.path.isfile(cache_path)\n",
    "    if cache_hit:\n",
    "        # NOTE: pickle.load can execute code from a crafted file; this path\n",
    "        # should only ever hold files written by this function.\n",
    "        with open(cache_path, \"rb\") as cache_file:\n",
    "            return pickle.load(cache_file)\n",
    "\n",
    "    encoded = model.encode(texts, normalize_embeddings=True)\n",
    "    os.makedirs(os.path.dirname(cache_path), exist_ok=True)\n",
    "    with open(cache_path, \"wb\") as cache_file:\n",
    "        pickle.dump(encoded, cache_file)\n",
    "    return encoded"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Boilerplate code for Quantization listings\n",
    "### Generating embeddings and benchmark data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import faiss\n",
    "from aips.data_loaders.outdoors import load_dataframe\n",
    "\n",
    "def display_results(scores, ids, data):\n",
    "    \"\"\"Build result rows for FAISS scores/ids, display them, and return them.\n",
    "\n",
    "    `data` stays in the signature for existing callers but is not used:\n",
    "    generate_search_results accepts only (scores, ids) and loads the outdoors\n",
    "    posts itself, so forwarding `data` to it raised a TypeError.\n",
    "    \"\"\"\n",
    "    results = generate_search_results(scores, ids)\n",
    "    display(results)\n",
    "    return results\n",
    "\n",
    "def get_outdoors_data():\n",
    "    \"\"\"Load data/outdoors/posts.csv and return its rows as a list of dicts.\"\"\"\n",
    "    posts = load_dataframe(\"data/outdoors/posts.csv\")\n",
    "    row_dicts = posts.rdd.map(lambda row: row.asDict())\n",
    "    return list(row_dicts.collect())\n",
    "\n",
    "def display_statistics(search_results, baseline_search_results=None, start_message=\"Recall\"):\n",
    "    \"\"\"Print search time, index size and recall for one benchmark run.\n",
    "\n",
    "    Args:\n",
    "        search_results: dict with \"index_name\", \"time_taken\" (ms) and \"size\"\n",
    "            (bytes); \"results\" is also read when a baseline is given.\n",
    "        baseline_search_results: optional dict with \"time_taken\", \"size\" and\n",
    "            \"results\". When given, percentage changes relative to it are shown\n",
    "            and recall is computed with calculate_recall; otherwise recall is\n",
    "            reported as 1.0.\n",
    "        start_message: label printed in front of the recall value.\n",
    "    \"\"\"\n",
    "    name = search_results[\"index_name\"]\n",
    "    elapsed_ms = search_results[\"time_taken\"]\n",
    "    size_bytes = search_results[\"size\"]\n",
    "\n",
    "    time_note, size_note, recall = \"\", \"\", 1.0\n",
    "    if baseline_search_results:\n",
    "        baseline_ms = baseline_search_results[\"time_taken\"]\n",
    "        time_gain = round((baseline_ms - elapsed_ms) * 100 / baseline_ms, 2)\n",
    "        time_note = f\" ({time_gain}% improvement)\"\n",
    "        baseline_size = baseline_search_results[\"size\"]\n",
    "        size_gain = round((baseline_size - size_bytes) * 100 / baseline_size, 2)\n",
    "        size_note = f\" ({size_gain}% improvement)\"\n",
    "        recall = calculate_recall(baseline_search_results[\"results\"],\n",
    "                                  search_results[\"results\"])\n",
    "\n",
    "    size_mb = round(size_bytes / 1000000, 2)\n",
    "    print(f\"{name} search took: {elapsed_ms:.3f} ms{time_note}\")\n",
    "    print(f\"{name} index size: {size_mb} MB{size_note}\")\n",
    "    print(f\"{start_message}: {round(recall, 4)}\")\n",
    "\n",
    "def calculate_recall(scored_full_results, scored_quantized_results):\n",
    "    \"\"\"Average, over queries, the share of optimized result ids found in the baseline.\n",
    "\n",
    "    Args:\n",
    "        scored_full_results: one list of result dicts (each with an \"id\") per query,\n",
    "            from the baseline search.\n",
    "        scored_quantized_results: the same structure from the optimized search;\n",
    "            entry i is compared with baseline entry i.\n",
    "\n",
    "    Returns:\n",
    "        Mean of |baseline ids & optimized ids| / |optimized ids| across queries.\n",
    "        A query with no optimized results contributes 0.0, and an empty query\n",
    "        list yields 0.0, instead of raising ZeroDivisionError.\n",
    "    \"\"\"\n",
    "    recalls = []\n",
    "    for i, full_hits in enumerate(scored_full_results):\n",
    "        full_ids = {r[\"id\"] for r in full_hits}\n",
    "        quantized_ids = {r[\"id\"] for r in scored_quantized_results[i]}\n",
    "        if not quantized_ids:\n",
    "            # Nothing was retrieved for this query, so nothing was recalled.\n",
    "            recalls.append(0.0)\n",
    "            continue\n",
    "        recalls.append(len(full_ids & quantized_ids) / len(quantized_ids))\n",
    "    if not recalls:\n",
    "        return 0.0\n",
    "    return sum(recalls) / len(recalls)\n",
    "\n",
    "def generate_search_results(faiss_scores, faiss_ids):\n",
    "    \"\"\"Turn FAISS score/id arrays into per-query lists of result dicts.\n",
    "\n",
    "    Args:\n",
    "        faiss_scores: per-query sequences of scores.\n",
    "        faiss_ids: per-query sequences of ids, used as positions into the list\n",
    "            returned by get_outdoors_data().\n",
    "\n",
    "    Returns:\n",
    "        One list per query of {\"score\", \"title\", \"body\", \"id\"} dicts. Ids below\n",
    "        zero are skipped: FAISS uses -1 to pad result slots it could not fill,\n",
    "        and indexing with -1 would silently return the last post.\n",
    "    \"\"\"\n",
    "    outdoors_data = get_outdoors_data()\n",
    "    faiss_results = []\n",
    "    for i, query_scores in enumerate(faiss_scores):\n",
    "        results = []\n",
    "        for j, doc_id in enumerate(faiss_ids[i]):\n",
    "            doc_id = int(doc_id)\n",
    "            if doc_id < 0:\n",
    "                continue\n",
    "            post = outdoors_data[doc_id]\n",
    "            results.append({\"score\": query_scores[j],\n",
    "                            \"title\": post[\"title\"],\n",
    "                            \"body\": post[\"body\"],\n",
    "                            \"id\": doc_id})\n",
    "        faiss_results.append(results)\n",
    "    return faiss_results\n",
    "\n",
    "def time_and_execute_search(index, index_name, query_embeddings, k=25, num_runs=100):\n",
    "    \"\"\"Run the same search `num_runs` times and report results plus mean latency.\n",
    "\n",
    "    Args:\n",
    "        index: object exposing search(query_embeddings, k=k) -> (scores, ids).\n",
    "        index_name: path of the index file on disk, used both as the label and\n",
    "            for os.path.getsize; pass a falsy value to omit both.\n",
    "        query_embeddings: query vectors passed to index.search.\n",
    "        k: number of neighbours requested per query.\n",
    "        num_runs: number of timed repetitions (at least 1).\n",
    "\n",
    "    Returns:\n",
    "        Dict with \"results\", \"time_taken\" (mean milliseconds), \"faiss_scores\"\n",
    "        and \"faiss_ids\" from the last run, plus \"index_name\" and \"size\" (bytes)\n",
    "        when index_name is given.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if num_runs is below 1, since no search would be executed.\n",
    "    \"\"\"\n",
    "    if num_runs < 1:\n",
    "        raise ValueError(\"num_runs must be at least 1\")\n",
    "\n",
    "    search_times = []\n",
    "    faiss_scores = None\n",
    "    faiss_ids = None\n",
    "\n",
    "    for _ in range(num_runs):\n",
    "        # perf_counter is monotonic and high-resolution, unlike time.time().\n",
    "        start_time = time.perf_counter()\n",
    "        faiss_scores, faiss_ids = index.search(query_embeddings, k=k)\n",
    "        search_times.append((time.perf_counter() - start_time) * 1000)\n",
    "\n",
    "    results = {\"results\": generate_search_results(faiss_scores, faiss_ids),\n",
    "               \"time_taken\": numpy.average(search_times),\n",
    "               \"faiss_scores\": faiss_scores, \"faiss_ids\": faiss_ids}\n",
    "    index_stats = {}\n",
    "    if index_name:\n",
    "        index_stats = {\"index_name\": index_name,\n",
    "                       \"size\": os.path.getsize(index_name)}\n",
    "    return results | index_stats"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Listing 13.20"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Indexing full-precision embeddings using FAISS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sentence_transformers.quantization import quantize_embeddings\n",
    "\n",
    "# Rebinds the global `model` with the same model name and arguments used in the first code cell.\n",
    "model = SentenceTransformer(\n",
    "          \"mixedbread-ai/mxbai-embed-large-v1\",\n",
    "          similarity_fn_name=SimilarityFunction.DOT_PRODUCT,\n",
    "          truncate_dim=1024)\n",
    "\n",
    "def index_full_precision_embeddings(doc_embeddings, name):\n",
    "    \"\"\"Add doc_embeddings to a new faiss.IndexFlatIP, write it to the file `name`, and return it.\"\"\"\n",
    "    dimensions = doc_embeddings.shape[1]\n",
    "    flat_index = faiss.IndexFlatIP(dimensions)\n",
    "    flat_index.add(doc_embeddings)\n",
    "    faiss.write_index(flat_index, name)\n",
    "    return flat_index\n",
    "\n",
    "def get_outdoors_embeddings(model):\n",
    "    \"\"\"Return a numpy array with one embedding per outdoors post.\n",
    "\n",
    "    Each post is embedded as its title and body joined by a single space.\n",
    "    Embeddings come from get_embeddings under the cache name\n",
    "    \"outdoors_mrl_normed\", so `model` is only invoked when no cache file exists.\n",
    "    \"\"\"\n",
    "    posts = load_dataframe(\"data/outdoors/posts.csv\").collect()\n",
    "    post_texts = []\n",
    "    for post in posts:\n",
    "        post_texts.append(post[\"title\"] + \" \" + post[\"body\"])\n",
    "    #TODO: This will take 2-5 hours to run the first time if not cached. Upload to github to save readers hassle.\n",
    "    embeddings = get_embeddings(post_texts, model, \"outdoors_mrl_normed\")\n",
    "    return numpy.array(embeddings)\n",
    "\n",
    "# Full-precision document embeddings (loaded from the pickle cache when present) and the flat\n",
    "# baseline index, which is also written to a file named full_embeddings in the working directory.\n",
    "doc_embeddings = get_outdoors_embeddings(model) #takes 2-3 hours if not cached\n",
    "full_index = index_full_precision_embeddings(doc_embeddings,\n",
    "                                          \"full_embeddings\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Listing 13.21"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generating full-precision query embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "full_embeddings search took: 6.362 ms\n",
      "full_embeddings index size: 75.6 MB\n",
      "Recall: 1.0\n"
     ]
    }
   ],
   "source": [
    "def get_test_queries():\n",
    "    \"\"\"Return the fixed list of nine benchmark query strings.\"\"\"\n",
    "    test_queries = [\n",
    "        \"tent poles\", \"hiking trails\", \"mountain forests\",\n",
    "        \"white water\", \"best waterfalls\", \"mountain biking\",\n",
    "        \"snowboarding slopes\", \"bungee jumping\", \"public parks\",\n",
    "    ]\n",
    "    return test_queries\n",
    "\n",
    "# Encode the test queries as normalized numpy vectors, matching the normalization\n",
    "# applied to the document embeddings in get_embeddings.\n",
    "queries = get_test_queries()\n",
    "query_embeddings = model.encode(queries,\n",
    "                  convert_to_numpy=True,\n",
    "              normalize_embeddings=True)\n",
    "\n",
    "# Baseline run: no comparison results are passed, so display_statistics reports recall as 1.0.\n",
    "full_results = time_and_execute_search(\n",
    "    full_index, \"full_embeddings\",\n",
    "    query_embeddings, k=25)\n",
    "display_statistics(full_results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Listing 13.22"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Functions to benchmark quantized search approaches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_search(full_index, optimized_index, optimized_index_name,\n",
    "                    query_embeddings, optimized_query_embeddings,\n",
    "                    k=25, display=True, log=False):\n",
    "    \"\"\"Benchmark an optimized index against the full-precision index.\n",
    "\n",
    "    Both indexes are searched with time_and_execute_search using the same k;\n",
    "    the baseline is always labelled \"full_embeddings\". When `display` is true\n",
    "    the comparison is printed with display_statistics. `log` is accepted but\n",
    "    not used.\n",
    "\n",
    "    Returns:\n",
    "        (optimized_results, full_results) as returned by time_and_execute_search.\n",
    "    \"\"\"\n",
    "    baseline = time_and_execute_search(\n",
    "        full_index, \"full_embeddings\", query_embeddings, k=k)\n",
    "    optimized = time_and_execute_search(\n",
    "        optimized_index, optimized_index_name, optimized_query_embeddings, k=k)\n",
    "    if display:\n",
    "        display_statistics(optimized, baseline)\n",
    "    return optimized, baseline\n",
    "\n",
    "def evaluate_rerank_search(full_index, optimized_index,\n",
    "                           query_embeddings,\n",
    "                           optimized_embeddings,\n",
    "                           k=50, limit=25):\n",
    "    \"\"\"Over-fetch k candidates from the optimized index, rerank them with\n",
    "    full-precision scores, and print the recall of the top `limit` results.\n",
    "\n",
    "    Args:\n",
    "        full_index: baseline index searched with query_embeddings.\n",
    "        optimized_index: index searched with optimized_embeddings.\n",
    "        query_embeddings: full-precision query vectors, also used for rescoring.\n",
    "        optimized_embeddings: query vectors in the optimized index's format.\n",
    "        k: number of candidates fetched per query before reranking.\n",
    "        limit: number of reranked results kept per query.\n",
    "\n",
    "    Reads the global `model` to load the full-precision document embeddings.\n",
    "    Prints the recall and returns None.\n",
    "    \"\"\"\n",
    "    results, full_results = evaluate_search(full_index, optimized_index, None, query_embeddings,\n",
    "                                            optimized_embeddings, display=False, k=k)\n",
    "\n",
    "    doc_embeddings = get_outdoors_embeddings(model) #This can point to a cheap on-disk data source containing the original full-precision embeddings\n",
    "    rescore_scores, rescore_ids = [], []\n",
    "    for i, candidate_ids in enumerate(results[\"faiss_ids\"]):\n",
    "        candidate_ids = numpy.asarray(candidate_ids)\n",
    "        # FAISS pads unfilled slots with -1; indexing with it would pull in the last document.\n",
    "        candidate_ids = candidate_ids[candidate_ids >= 0]\n",
    "        scores = query_embeddings[i] @ doc_embeddings[candidate_ids].T\n",
    "        best_first = scores.argsort()[::-1][:limit]\n",
    "        rescore_scores.append(scores[best_first])\n",
    "        rescore_ids.append(candidate_ids[best_first])\n",
    "\n",
    "    reranked = generate_search_results(rescore_scores, rescore_ids)\n",
    "    # Compare like with like: the baseline was fetched with k results per query, so\n",
    "    # trim it to its top `limit` before measuring how many of them the rerank recovered.\n",
    "    baseline_top = [hits[:limit] for hits in full_results[\"results\"]]\n",
    "    recall = calculate_recall(baseline_top, reranked)\n",
    "    print(f\"Reranked recall: {round(recall, 4)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Listing 13.23\n",
    "### int8 quantization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Int8 embeddings shape: (18456, 1024)\n",
      "int8_embeddings search took: 8.514 ms (-41.59% improvement)\n",
      "int8_embeddings index size: 18.91 MB (74.99% improvement)\n",
      "Recall: 0.9289\n",
      "Reranked recall: 1.0\n"
     ]
    }
   ],
   "source": [
    "def index_int8_embeddings(doc_embeddings, name):\n",
    "    \"\"\"Quantize doc_embeddings to int8, index them in an 8-bit scalar-quantizer\n",
    "    FAISS index, write the index to the file `name`, and return it.\n",
    "\n",
    "    Prints the shape of the quantized array as a side effect.\n",
    "    \"\"\"\n",
    "    quantized = quantize_embeddings(doc_embeddings, precision=\"int8\")\n",
    "    print(\"Int8 embeddings shape:\", quantized.shape)\n",
    "    dimensions = quantized.shape[1]\n",
    "    sq_index = faiss.IndexScalarQuantizer(dimensions,\n",
    "                                          faiss.ScalarQuantizer.QT_8bit)\n",
    "    sq_index.train(quantized)\n",
    "    sq_index.add(quantized)\n",
    "    faiss.write_index(sq_index, name)\n",
    "    return sq_index\n",
    "\n",
    "# Build the int8 index; it is written to a file named int8_embeddings in the working directory.\n",
    "int8_index_name = \"int8_embeddings\"\n",
    "int8_index = index_int8_embeddings(doc_embeddings, int8_index_name)\n",
    "\n",
    "# Quantize the queries to int8, passing the document embeddings as calibration data.\n",
    "quantized_queries = quantize_embeddings(query_embeddings,\n",
    "                   calibration_embeddings=doc_embeddings,\n",
    "                                        precision=\"int8\")\n",
    "\n",
    "# Compare against the full-precision index, first directly and then with full-precision reranking.\n",
    "evaluate_search(full_index, int8_index, int8_index_name,\n",
    "                query_embeddings, quantized_queries)\n",
    "evaluate_rerank_search(full_index, int8_index,\n",
    "                       query_embeddings, quantized_queries)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Listing 13.24\n",
    "### Binary Quantization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Binary embeddings shape: (18456, 128)\n",
      "binary_embeddings search took: 0.352 ms (94.36% improvement)\n",
      "binary_embeddings index size: 2.36 MB (96.87% improvement)\n",
      "Recall: 0.6044\n",
      "Reranked recall: 1.0\n"
     ]
    }
   ],
   "source": [
    "def index_binary_embeddings(doc_embeddings, binary_index_name):\n",
    "    \"\"\"Binary-quantize doc_embeddings, index them in a faiss.IndexBinaryFlat,\n",
    "    write the index to the file `binary_index_name`, and return it.\n",
    "\n",
    "    The quantized array is cast to uint8 and the index dimension is its\n",
    "    column count times 8. Prints the quantized shape as a side effect.\n",
    "    \"\"\"\n",
    "    packed = quantize_embeddings(doc_embeddings, precision=\"binary\")\n",
    "    packed = packed.astype(numpy.uint8)\n",
    "    print(\"Binary embeddings shape:\", packed.shape)\n",
    "    num_bits = packed.shape[1] * 8\n",
    "    binary_index = faiss.IndexBinaryFlat(num_bits)\n",
    "    binary_index.add(packed)\n",
    "    faiss.write_index_binary(binary_index, binary_index_name)\n",
    "    return binary_index\n",
    "\n",
    "# Build the binary index; it is written to a file named binary_embeddings in the working directory.\n",
    "binary_index_name = \"binary_embeddings\"\n",
    "binary_index = index_binary_embeddings(doc_embeddings, binary_index_name)\n",
    "\n",
    "# Queries get the same binary quantization and uint8 cast as the documents.\n",
    "# This rebinds the global `quantized_queries` set by the int8 listing.\n",
    "quantized_queries = quantize_embeddings(query_embeddings,\n",
    "                   calibration_embeddings=doc_embeddings,\n",
    "                                      precision=\"binary\").astype(numpy.uint8)\n",
    "\n",
    "# Compare against the full-precision index, first directly and then with full-precision reranking.\n",
    "evaluate_search(full_index, binary_index, binary_index_name,\n",
    "                query_embeddings, quantized_queries)\n",
    "evaluate_rerank_search(full_index, binary_index,\n",
    "                       query_embeddings, quantized_queries)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Listing 13.25\n",
    "### Product Quantization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pq_embeddings search took: 0.519 ms (90.04% improvement)\n",
      "pq_embeddings index size: 1.34 MB (98.22% improvement)\n",
      "Recall: 0.3333\n",
      "Reranked recall: 0.68\n"
     ]
    }
   ],
   "source": [
    "def index_pq_embeddings(doc_embeddings, index_name, num_subvectors=16):\n",
    "    \"\"\"Train a faiss.IndexPQ on doc_embeddings, add them, write the index to\n",
    "    the file `index_name`, and return it.\n",
    "\n",
    "    Each vector is split into `num_subvectors` parts encoded with 8 bits each.\n",
    "    \"\"\"\n",
    "    bits_per_code = 8\n",
    "    vector_size = doc_embeddings.shape[1]\n",
    "    pq_index = faiss.IndexPQ(vector_size, num_subvectors, bits_per_code)\n",
    "    pq_index.train(doc_embeddings)\n",
    "    pq_index.add(doc_embeddings)\n",
    "    # Persist the trained index to disk.\n",
    "    faiss.write_index(pq_index, index_name)\n",
    "    return pq_index\n",
    "\n",
    "# Build the PQ index with the default 16 subvectors; it is written to a file named pq_embeddings.\n",
    "pq_index_name = \"pq_embeddings\"\n",
    "pq_index = index_pq_embeddings(doc_embeddings, pq_index_name)\n",
    "\n",
    "# The PQ index is searched with the same full-precision query vectors as the baseline.\n",
    "evaluate_search(full_index, pq_index, pq_index_name, query_embeddings, query_embeddings)\n",
    "evaluate_rerank_search(full_index, pq_index, query_embeddings, query_embeddings)\n",
    "\n",
    "#TODO: consider adding IVFFlatPQ as optimization for speed to show at end"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run again with 64 subvectors instead of 16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pq_embeddings search took: 1.768 ms (72.07% improvement)\n",
      "pq_embeddings index size: 2.23 MB (97.05% improvement)\n",
      "Recall: 0.5778\n",
      "Reranked recall: 0.9911\n"
     ]
    }
   ],
   "source": [
    "# Reuses pq_index_name, so the index file written by the 16-subvector run is overwritten.\n",
    "pq_index = index_pq_embeddings(doc_embeddings, pq_index_name, num_subvectors=64)\n",
    "evaluate_search(full_index, pq_index, pq_index_name, query_embeddings, query_embeddings)\n",
    "evaluate_rerank_search(full_index, pq_index, query_embeddings, query_embeddings)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Listing 13.26\n",
    "### Matryoshka Representation Learning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original embeddings shape: (18456, 1024)\n",
      "mrl_embeddings_512 embeddings shape: (18456, 512)\n",
      "\n",
      "mrl_embeddings_512 search took: 2.868 ms (52.67% improvement)\n",
      "mrl_embeddings_512 index size: 37.8 MB (50.0% improvement)\n",
      "Recall: 0.7022\n",
      "Reranked recall: 1.0\n",
      "mrl_embeddings_256 embeddings shape: (18456, 256)\n",
      "\n",
      "mrl_embeddings_256 search took: 1.197 ms (79.86% improvement)\n",
      "mrl_embeddings_256 index size: 18.9 MB (75.0% improvement)\n",
      "Recall: 0.4756\n",
      "Reranked recall: 0.9689\n",
      "mrl_embeddings_128 embeddings shape: (18456, 128)\n",
      "\n",
      "mrl_embeddings_128 search took: 0.549 ms (91.0% improvement)\n",
      "mrl_embeddings_128 index size: 9.45 MB (87.5% improvement)\n",
      "Recall: 0.2489\n",
      "Reranked recall: 0.64\n"
     ]
    }
   ],
   "source": [
    "def get_mrl_embeddings(embeddings, num_dimensions, normalize=True):\n",
    "    \"\"\"Truncate each embedding to its first `num_dimensions` values.\n",
    "\n",
    "    Truncating a unit-length vector leaves it shorter than unit length, which\n",
    "    skews inner-product rankings by how much of each vector's norm sits in the\n",
    "    kept dimensions. With normalize=True (the default) every truncated row is\n",
    "    rescaled to unit L2 norm; all-zero rows stay zero. Pass normalize=False to\n",
    "    get the raw slices.\n",
    "\n",
    "    Args:\n",
    "        embeddings: iterable of equal-length vectors (for example a 2-D array).\n",
    "        num_dimensions: number of leading values to keep from each vector.\n",
    "        normalize: whether to L2-normalize the truncated rows.\n",
    "\n",
    "    Returns:\n",
    "        A new numpy array with one truncated row per input embedding.\n",
    "    \"\"\"\n",
    "    mrl_embeddings = numpy.array([e[:num_dimensions] for e in embeddings])\n",
    "    if normalize and mrl_embeddings.size:\n",
    "        norms = numpy.linalg.norm(mrl_embeddings, axis=1, keepdims=True)\n",
    "        norms[norms == 0] = 1  # leave all-zero rows untouched instead of dividing by zero\n",
    "        mrl_embeddings = mrl_embeddings / norms\n",
    "    return mrl_embeddings\n",
    "\n",
    "def index_mrl_embeddings(doc_embeddings, num_dimensions, mrl_index_name):\n",
    "    \"\"\"Truncate doc_embeddings to `num_dimensions`, index them with\n",
    "    index_full_precision_embeddings under `mrl_index_name`, and return the index.\n",
    "\n",
    "    Prints the shape of the truncated array as a side effect.\n",
    "    \"\"\"\n",
    "    truncated = get_mrl_embeddings(doc_embeddings, num_dimensions)\n",
    "    print(f\"{mrl_index_name} embeddings shape:\", truncated.shape)\n",
    "    return index_full_precision_embeddings(truncated, mrl_index_name)\n",
    "\n",
    "print(f\"Original embeddings shape: {doc_embeddings.shape}\")\n",
    "original_dimensions = doc_embeddings.shape[1] #1024\n",
    "\n",
    "# Benchmark three truncation levels: half, a quarter and an eighth of the original width.\n",
    "for num_dimensions in [original_dimensions//2, #512\n",
    "                       original_dimensions//4, #256\n",
    "                       original_dimensions//8]: #128\n",
    "\n",
    "    # Each level gets its own index file name; documents and queries are truncated the same way.\n",
    "    mrl_index_name = f\"mrl_embeddings_{num_dimensions}\"\n",
    "    mrl_index = index_mrl_embeddings(doc_embeddings, num_dimensions, mrl_index_name)\n",
    "    mrl_queries = get_mrl_embeddings(query_embeddings, num_dimensions)\n",
    "    print(\"\\n\", end=\"\")\n",
    "    \n",
    "    # Reranking uses the untruncated query embeddings.\n",
    "    evaluate_search(full_index, mrl_index, mrl_index_name,\n",
    "                            query_embeddings, mrl_queries)\n",
    "    evaluate_rerank_search(full_index, mrl_index,\n",
    "                   query_embeddings, mrl_queries)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Listing 13.27"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Combining techniques"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "binary_ivf_mrl_embeddings search took: 0.114 ms (98.06% improvement)\n",
      "binary_ivf_mrl_embeddings index size: 1.35 MB (98.22% improvement)\n",
      "Recall: 0.3511\n",
      "Reranked recall: 0.7244\n"
     ]
    }
   ],
   "source": [
    "def index_binary_ivf_mrl_embeddings(reduced_mrl_doc_embeddings, binary_index_name,\n",
    "                                    num_clusters=256, nprobe=4):\n",
    "    \"\"\"Binary-quantize already-truncated embeddings and index them with an IVF binary index.\n",
    "\n",
    "    Args:\n",
    "        reduced_mrl_doc_embeddings: 2-D array of truncated document embeddings;\n",
    "            its column count is used as the binary index dimension.\n",
    "        binary_index_name: file the index is written to.\n",
    "        num_clusters: number of IVF clusters passed to faiss.IndexBinaryIVF\n",
    "            (default 256, the value previously hard-coded).\n",
    "        nprobe: value assigned to the index's nprobe attribute (default 4,\n",
    "            the value previously hard-coded).\n",
    "\n",
    "    Returns:\n",
    "        The trained and populated faiss.IndexBinaryIVF.\n",
    "    \"\"\"\n",
    "    #Binary quantization\n",
    "    binary_embeddings = quantize_embeddings(reduced_mrl_doc_embeddings,\n",
    "                        calibration_embeddings=reduced_mrl_doc_embeddings,\n",
    "                        precision=\"binary\").astype(numpy.uint8)\n",
    "    dimensions = reduced_mrl_doc_embeddings.shape[1]\n",
    "    quantizer = faiss.IndexBinaryFlat(dimensions)\n",
    "\n",
    "    #ANN: IVF Flat Algorithm\n",
    "    index = faiss.IndexBinaryIVF(\n",
    "        quantizer, dimensions, num_clusters)\n",
    "    index.nprobe = nprobe\n",
    "\n",
    "    index.train(binary_embeddings)\n",
    "    index.add(binary_embeddings)\n",
    "    faiss.write_index_binary(index, binary_index_name)\n",
    "    return index\n",
    "\n",
    "# Step 1: truncate the document embeddings to half their original width.\n",
    "mrl_dimensions = doc_embeddings.shape[1] // 2  #MRL 1024 => 512 dimensions\n",
    "reduced_mrl_doc_embeddings =  get_mrl_embeddings(\n",
    "    doc_embeddings, mrl_dimensions)\n",
    "\n",
    "# Step 2: binary-quantize the truncated embeddings and index them with IVF.\n",
    "binary_ivf_mrl_index_name = \"binary_ivf_mrl_embeddings\"\n",
    "binary_ivf_mrl_index = index_binary_ivf_mrl_embeddings(\n",
    "    reduced_mrl_doc_embeddings, binary_ivf_mrl_index_name)\n",
    "\n",
    "# Queries go through the same truncation, binary quantization and uint8 cast as the documents.\n",
    "mrl_queries = get_mrl_embeddings(query_embeddings, \n",
    "                                  mrl_dimensions)\n",
    "quantized_queries = quantize_embeddings(mrl_queries,\n",
    "      calibration_embeddings=reduced_mrl_doc_embeddings,\n",
    "      precision=\"binary\").astype(numpy.uint8)\n",
    "\n",
    "# Compare against the full-precision index, first directly and then with full-precision reranking.\n",
    "evaluate_search(full_index, binary_ivf_mrl_index,\n",
    "    binary_ivf_mrl_index_name, \n",
    "    query_embeddings, quantized_queries)\n",
    "evaluate_rerank_search(\n",
    "    full_index, binary_ivf_mrl_index,\n",
    "    query_embeddings, quantized_queries)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Listing 13.28\n",
    "Rerank search results with cross-encoder \n",
    "\n",
    "#### Located in the [Chapter 13 Semantic Search notebook](4.semantic-search.ipynb#listing-13.28)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Up next: [Chapter 14: Question Answering with a Fine-tuned Large Language Model](../ch14/1.question-answering-visualizer.ipynb)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
